"""Plot hist of model types split features."""
from emodel_generalisation.mcmc import load_chains
import numpy as np
import matplotlib.pyplot as plt

# All splitting features live in the "features" column group under this common prefix.
FEATURE_PREFIX = "Step_ReboundBurst_burst.soma.v."
# Only models with a cost strictly below this value are analysed.
MAX_COST = 2.0
# A non-tonic model is "runaway" when burst_runaway is below the first threshold AND
# time_to_last_spike is above the second; both are also drawn as red lines on the histograms.
RUNAWAY_MAX_BURST_RUNAWAY = 0.05
RUNAWAY_MIN_TIME_TO_LAST_SPIKE = 20000
# Remaining (non-tonic, non-runaway) models with at least this many bursts are "spp",
# the others are "ecel".
SPP_MIN_BURST_NUMBER = 5


def get_feature(mcmc_df, name):
    """Return the ``Step_ReboundBurst_burst.soma.v.<name>`` column of ``mcmc_df["features"]``."""
    return mcmc_df["features"][FEATURE_PREFIX + name]


def split_model_types(mcmc_df):
    """Split the models of ``mcmc_df`` into types using the rebound-burst features.

    Rules, applied in order:
      - tonic: ``tonic_after_burst > 0``
      - runaway: not tonic, ``burst_runaway < RUNAWAY_MAX_BURST_RUNAWAY`` and
        ``time_to_last_spike > RUNAWAY_MIN_TIME_TO_LAST_SPIKE``
      - ecel / spp: neither tonic nor runaway, split on ``all_burst_number`` at
        ``SPP_MIN_BURST_NUMBER`` (ecel below, spp at or above).

    Every model belongs to at most one type.

    Returns:
        tuple: boolean masks ``(tonic, runaway, ecel, spp)`` aligned with ``mcmc_df``.
    """
    tonic = get_feature(mcmc_df, "tonic_after_burst") > 0.0
    runaway = (
        ~tonic
        & (get_feature(mcmc_df, "burst_runaway") < RUNAWAY_MAX_BURST_RUNAWAY)
        & (get_feature(mcmc_df, "time_to_last_spike") > RUNAWAY_MIN_TIME_TO_LAST_SPIKE)
    )
    bursting = ~tonic & ~runaway
    burst_number = get_feature(mcmc_df, "all_burst_number")
    ecel = bursting & (burst_number < SPP_MIN_BURST_NUMBER)
    spp = bursting & (burst_number >= SPP_MIN_BURST_NUMBER)
    return tonic, runaway, ecel, spp


def _overlay_hist(all_values, subsets, *, bins, hist_range, log):
    """Draw a thick grey step histogram of ``all_values``, then one step histogram per subset.

    ``subsets`` is a sequence of ``(label, values, color)`` tuples, drawn in order. All
    histograms share ``bins`` and ``hist_range`` so the curves are directly comparable.
    """
    plt.hist(
        all_values,
        histtype="step",
        bins=bins,
        range=hist_range,
        log=log,
        label="all",
        lw=2,
        color="0.5",
    )
    for label, values, color in subsets:
        plt.hist(
            values,
            histtype="step",
            bins=bins,
            range=hist_range,
            log=log,
            label=label,
            color=color,
        )


def _finish_figure(xlabel, filename, **legend_kwargs):
    """Add the legend and x label, save the current figure to ``filename``, then close it."""
    plt.legend(**legend_kwargs)
    plt.xlabel(xlabel)
    plt.tight_layout()
    plt.savefig(filename)
    # Release the figure once saved; a new one is opened for each plot.
    plt.close()


def main(log=True):
    """Load the MCMC chains, classify models and save one histogram per splitting feature.

    Args:
        log (bool): use a logarithmic y axis for all histograms.
    """
    mcmc_df = load_chains("../../mcmc_run/run_df.csv", base_path="../../mcmc_run")
    mcmc_df = mcmc_df[mcmc_df.cost < MAX_COST]
    mask_tonic, mask_runaway, mask_ecel, mask_spp = split_model_types(mcmc_df)

    time_to_last_spike = get_feature(mcmc_df, "time_to_last_spike")
    plt.figure(figsize=(5, 3))
    _overlay_hist(
        time_to_last_spike,
        [("runaway", time_to_last_spike[mask_runaway], "C2")],
        bins=100,
        hist_range=(5500, 25000),
        log=log,
    )
    plt.axvline(RUNAWAY_MIN_TIME_TO_LAST_SPIKE, c="r", label="runaway threshold")
    _finish_figure("time_to_last_spike", "hist_time_to_last_spike.pdf")

    # Clip to the plotted range so outliers pile up in the edge bins instead of being dropped.
    burst_runaway = np.clip(get_feature(mcmc_df, "burst_runaway"), -0.5, 2)
    plt.figure(figsize=(5, 3))
    _overlay_hist(
        burst_runaway,
        [("runaway", burst_runaway[mask_runaway], "C2")],
        bins=200,
        hist_range=(-0.5, 2),
        log=log,
    )
    plt.axvline(RUNAWAY_MAX_BURST_RUNAWAY, c="r", label="runaway threshold")
    _finish_figure("burst_runaway", "hist_burst_runaway.pdf")

    burst_number = get_feature(mcmc_df, "all_burst_number")
    plt.figure(figsize=(5, 3))
    _overlay_hist(
        burst_number,
        [
            ("ecel", burst_number[mask_ecel], "C0"),
            # mask_spp already excludes runaway models, no extra masking needed.
            ("spp", burst_number[mask_spp], "C1"),
            ("runaway", burst_number[mask_runaway], "C2"),
            ("tonic", burst_number[mask_tonic], "C3"),
        ],
        bins=100,
        hist_range=(0, 100),
        log=log,
    )
    _finish_figure("burst number", "hist_all_burst_number.pdf", loc="best")


if __name__ == "__main__":
    main()
